baae8d
@@ -33,6 +33,7 @@
import javax.websocket.Endpoint;
 import javax.websocket.Extension;
 import javax.websocket.server.ServerEndpointConfig;
 
+import io.undertow.server.HttpServerExchange;
 import io.undertow.server.HttpUpgradeListener;
 import io.undertow.servlet.api.InstanceFactory;
 import io.undertow.servlet.api.InstanceHandle;
@@ -51,6 +52,7 @@
import io.undertow.websockets.jsr.handshake.JsrHybi07Handshake;
 import io.undertow.websockets.jsr.handshake.JsrHybi08Handshake;
 import io.undertow.websockets.jsr.handshake.JsrHybi13Handshake;
 import io.undertow.websockets.spi.WebSocketHttpExchange;
+import org.xnio.StreamConnection;
 
 import org.springframework.http.server.ServerHttpRequest;
 import org.springframework.http.server.ServerHttpResponse;
@@ -161,39 +163,31 @@
public class UndertowRequestUpgradeStrategy extends AbstractStandardUpgradeStrat
 
 		HttpServletRequest servletRequest = getHttpServletRequest(request);
 		HttpServletResponse servletResponse = getHttpServletResponse(response);
+
 		final ServletWebSocketHttpExchange exchange = createHttpExchange(servletRequest, servletResponse);
 		exchange.putAttachment(HandshakeUtil.PATH_PARAMS, Collections.<String, String>emptyMap());
 
 		ServerWebSocketContainer wsContainer = (ServerWebSocketContainer) getContainer(servletRequest);
 		final EndpointSessionHandler endpointSessionHandler = new EndpointSessionHandler(wsContainer);
+
 		final ConfiguredServerEndpoint configuredServerEndpoint = createConfiguredServerEndpoint(
 				selectedProtocol, selectedExtensions, endpoint, servletRequest);
+
 		final Handshake handshake = getHandshakeToUse(exchange, configuredServerEndpoint);
 
-		HttpUpgradeListener upgradeListener = (HttpUpgradeListener) Proxy.newProxyInstance(
-				getClass().getClassLoader(), new Class<?>[] {HttpUpgradeListener.class},
-				new InvocationHandler() {
-					@Override
-					public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
-						if ("handleUpgrade".equals(method.getName())) {
-							Object connection = args[0];  // currently an XNIO StreamConnection
-							Object bufferPool = ReflectionUtils.invokeMethod(getBufferPoolMethod, exchange);
-							WebSocketChannel channel = (WebSocketChannel) ReflectionUtils.invokeMethod(
-									createChannelMethod, handshake, exchange, connection, bufferPool);
-							if (peerConnections != null) {
-								peerConnections.add(channel);
-							}
-							endpointSessionHandler.onConnect(exchange, channel);
-							return null;
-						}
-						else {
-							// any java.lang.Object method: equals, hashCode, toString...
-							return ReflectionUtils.invokeMethod(method, this, args);
-						}
-					}
-				});
-
-		exchange.upgradeChannel(upgradeListener);
+		exchange.upgradeChannel(new HttpUpgradeListener() {
+			@Override
+			public void handleUpgrade(StreamConnection connection, HttpServerExchange serverExchange) {
+				Object bufferPool = ReflectionUtils.invokeMethod(getBufferPoolMethod, exchange);
+				WebSocketChannel channel = (WebSocketChannel) ReflectionUtils.invokeMethod(
+						createChannelMethod, handshake, exchange, connection, bufferPool);
+				if (peerConnections != null) {
+					peerConnections.add(channel);
+				}
+				endpointSessionHandler.onConnect(exchange, channel);
+			}
+		});
+
 		handshake.handshake(exchange);
 	}
 
